Conversation
yeandy
left a comment
I'm wondering how many of these files we need?
- primus/backends/maxtext/input_pipeline/_hf_data_processing.py
- primus/backends/maxtext/input_pipeline/custom_packed_batch.py (I see this is deleted)
- primus/backends/maxtext/layers/attention_op.py
- primus/backends/maxtext/layers/attentions.py (I see this is deleted)
- primus/backends/maxtext/metric_logger.py
- primus/backends/maxtext/train.py
- primus/backends/maxtext/train_utils.py
I think they were added in the past for the purposes of patching. @amd-fuyuajin do you know if these are getting patched into the MaxText codebase when you run the training? Even if it is, it might be the same code as what is found in rocm/jax-training:maxtext-v26.1 actually. @llying-001 might know best.
I updated these files in the Primus repo to stay aligned with the `yeandy/update-patches-scaling-patch-v2-checkpoint-restore` branch in ROCm/maxtext.
- Add a timestamp to log filenames to prevent overwriting across runs
- Move tee logging outside the inline script to capture consolidated multi-node output in a single log file
- Make `--nodelist` conditional via the `NODE_LIST` env variable
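A minimal sketch of the launcher changes described above; the variable names (`LOG_DIR`, `SRUN_OPTS`) and the `echo` stand-in for the real `srun` invocation are illustrative, not the actual script contents:

```shell
#!/usr/bin/env bash
# Illustrative sketch only: shows timestamped log names, a conditional
# --nodelist, and tee-ing outside the inline script into one file.

LOG_DIR=${LOG_DIR:-./logs}
mkdir -p "$LOG_DIR"

# Timestamp each run's log so repeated launches do not overwrite it.
TS=$(date +%Y%m%d_%H%M%S)
LOG_FILE="$LOG_DIR/pretrain_${TS}.log"

# Pass --nodelist only when the NODE_LIST env variable is set.
SRUN_OPTS=(-N 2)
if [ -n "${NODE_LIST:-}" ]; then
  SRUN_OPTS+=(--nodelist "$NODE_LIST")
fi

# Tee outside the inline script so output from all nodes lands in one file.
echo "would run: srun ${SRUN_OPTS[*]} train.sh" 2>&1 | tee "$LOG_FILE"
```

Keeping the `tee` on the outer command (rather than inside the inline per-node script) is what consolidates multi-node output into a single file.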
- Set `TF_CPP_MIN_LOG_LEVEL=2`. Without this setting, an error occurs at the end when all training steps complete.
- `XLA_FLAGS` is case sensitive. Corrected a few values.
- Detect the backend framework in `primus-cli-direct.sh` and install the JAX dependencies accordingly
- If using AINIC (setting `USING_AINIC=1`), `03_enable_ainic.sh` will run. `LD_LIBRARY_PATH` is modified to make sure libraries are correctly loaded for JAX/MaxText.
- Set `XLA_PYTHON_CLIENT_MEM_FRACTION=.93` to avoid the `HSA_STATUS_ERROR_OUT_OF_RESOURCES` error during multi-node training
- Corrected some `XLA_FLAGS` values. They are case sensitive: `true` and `false` must not be capitalized.
- Set `TF_CPP_MIN_LOG_LEVEL=2` to suppress the error messages at the end of JAX/MaxText training

Here is an example of launching JAX/MaxText training on two nodes:
`./primus-cli --config runner/maxtext-test.yaml slurm srun -N 2 -- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml`
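The environment settings above can be collected into a snippet like the following. The values come from this PR's description; the `XLA_FLAGS` flag shown is only an example of the lowercase-boolean convention, not necessarily one of the flags this PR changed:

```shell
# Illustrative environment setup for multi-node JAX/MaxText runs.

# Leave headroom in the device heap to avoid
# HSA_STATUS_ERROR_OUT_OF_RESOURCES during multi-node training.
export XLA_PYTHON_CLIENT_MEM_FRACTION=.93

# Suppress the error spew emitted when all training steps complete.
export TF_CPP_MIN_LOG_LEVEL=2

# XLA_FLAGS values are case sensitive: write true/false in lowercase,
# never True/False. (Flag name here is an example only.)
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
```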
Problem: when running `apt install linux-headers-"$(uname -r)"`, the package name resolved to the wrong version number on some nodes and caused a "package not found" error. Solution: remove it from the package install list. This does not affect performance.
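If the headers package ever needs to come back, one defensive alternative to dropping it outright is to guard the install on the package actually being resolvable. This is a hedged sketch, not the change made in this PR (which simply removes the package); `PKGS` is an illustrative package list:

```shell
# Sketch: only add linux-headers-$(uname -r) to the install list when apt
# can actually resolve it, instead of failing the whole install on nodes
# where the running kernel has no matching headers package.
PKGS=(build-essential)   # illustrative base list

HDR="linux-headers-$(uname -r)"
if apt-cache show "$HDR" >/dev/null 2>&1; then
  PKGS+=("$HDR")
else
  echo "skipping $HDR: not available in apt cache"
fi
echo "would install: ${PKGS[*]}"
```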
1. Added examples for using AINIC in training
2. Added more examples for running preflight
3. Updated the argument format for the benchmark gemm command. The script had been changed, but the document was not updated.
Force-pushed from 2e31891 to 095b267
Pull request overview
This PR adds comprehensive support for JAX/MaxText backend testing and multi-node training capabilities, including AINIC network integration, improved checkpointing, and various model architecture enhancements.
Changes:
- Updated MaxText submodule to a newer commit
- Added AINIC configuration support with proper environment variable setup and library path ordering
- Enhanced MaxText backend with improved checkpointing, attention mechanisms, and decoder layer implementations
- Refactored dependency installation to detect framework type and install appropriate requirements
Reviewed changes
Copilot reviewed 34 out of 35 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| third_party/maxtext | Updated MaxText submodule reference to newer commit |
| runner/use_ainic.yaml | New configuration file for AINIC network setup with container options |
| runner/primus-cli-direct.sh | Added framework detection logic to install correct dependencies (JAX vs PyTorch) |
| runner/helpers/hooks/train/pretrain/maxtext/prepare.py | Removed problematic linux-headers package, adjusted memory limits and XLA flags |
| runner/helpers/hooks/03_enable_ainic.sh | Fixed LD_LIBRARY_PATH ordering to append instead of prepend paths |
| runner/.primus.yaml | Uncommented InfiniBand device for AINIC support |
| requirements-jax.txt | Simplified to core dependencies only |
| primus/pretrain.py | Enhanced MaxText path detection to support src subdirectory |
| primus/modules/trainer/maxtext/pre_trainer.py | Extended patching to include initialization, checkpointing, config types, and decoder layers |
| primus/configs/modules/maxtext/trainer_base.yaml | Updated configuration with new parameters and removed deprecated options |
| primus/configs/models/maxtext/llama3.1_405B.yaml | New model configuration for Llama 3.1 405B |
| primus/backends/maxtext/train_utils.py | Refactored emergency checkpoint logic and updated to use max_num_checkpoints_to_keep |
| primus/backends/maxtext/train.py | Major refactor with barrier synchronization, improved error handling, and new training features |
| primus/backends/maxtext/metric_logger.py | Updated to use MetadataKey enum constants |
| primus/backends/maxtext/max_utils.py | Added JAX distributed initialization functions for GPU/CPU/TPU |
| primus/backends/maxtext/layers/moe.py | Updated MoE layer to pass bias parameters |
| primus/backends/maxtext/layers/mixtral.py | New Primus-specific Mixtral decoder layer implementation |
| primus/backends/maxtext/layers/mistral.py | New Primus-specific Mistral decoder layer implementation |
| primus/backends/maxtext/layers/llama2.py | New Primus-specific Llama2 decoder layer implementation |
| primus/backends/maxtext/layers/gemma2.py | New Primus-specific Gemma2 decoder layer implementation |
| primus/backends/maxtext/layers/gemma.py | New Primus-specific Gemma decoder layer implementation |
| primus/backends/maxtext/layers/attentions.py | Removed entire attention implementation file |
| primus/backends/maxtext/layers/attention_op.py | Enhanced CUDNN Flash Attention with packing and context parallelism support |
| primus/backends/maxtext/input_pipeline/custom_packed_batch.py | Removed custom packing implementation |
| primus/backends/maxtext/input_pipeline/_hf_data_processing.py | Updated to use grain's native packing and added instruction format conversion |
| primus/backends/maxtext/configs/types.py | New Primus-specific MaxText config with WandB and Turbo support |
| primus/backends/maxtext/checkpointing.py | Added comprehensive checkpoint loading logic with single replica support |
| examples/run_slurm_pretrain.sh | Added NODE_LIST support and timestamped log files |
| examples/run_pretrain.sh | Reorganized AINIC configuration and updated XLA flags |
| examples/run_local_pretrain.sh | Updated default Docker image to maxtext-v26.1 |
| examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml | Reduced batch size from 12 to 11 |
| examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml | New training configuration for Llama 3.1 405B model |
| examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml | Updated remat policy |
| docs/cli/PRIMUS-CLI-GUIDE.md | Updated documentation with AINIC configuration examples and corrected command syntax |
Comments suppressed due to low confidence (4)
runner/primus-cli-direct.sh:1 - Array index arithmetic should use proper bash syntax. The expression `$((i+1))` correctly increments `i`, but when used inside an array subscript it should be written as `${args[i+1]}` without the extra parentheses, or the current form needs validation that `i+1` is within array bounds before access.
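A small sketch of the pattern this comment suggests. The `--nodes` option and the `args` contents are made up for illustration; the point is that bash evaluates arithmetic inside indexed-array subscripts, so `${args[i+1]}` is valid, and the bound is checked before the read:

```shell
#!/usr/bin/env bash
# Walk an argument array, reading an option's value with a bounds check.
args=(--config runner/maxtext-test.yaml --nodes 2)

i=0
while [ "$i" -lt "${#args[@]}" ]; do
  case "${args[i]}" in
    --nodes)
      # Verify i+1 is within bounds before accessing the value slot.
      if [ $((i+1)) -lt "${#args[@]}" ]; then
        nodes="${args[i+1]}"   # subscripts are arithmetic contexts already
      fi
      ;;
  esac
  i=$((i+1))
done
echo "nodes=$nodes"
```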
runner/primus-cli-direct.sh:1 - Python code embedded in a bash script should properly close file handles. The `open('$cfg_path')` should be wrapped in a context manager using `with open('$cfg_path') as f: cfg = yaml.safe_load(f)` to ensure the file is properly closed even if an exception occurs.
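The suggested pattern, shown standalone rather than embedded in bash. This is a sketch assuming PyYAML is installed; the config contents and the temp-file setup are illustrative:

```python
import os
import tempfile

import yaml  # PyYAML

# Create a small stand-in config file for the demonstration.
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as tmp:
    tmp.write("framework: jax\nnodes: 2\n")
    cfg_path = tmp.name

# Before: cfg = yaml.safe_load(open(cfg_path))  # handle leaks on error
# After: the context manager closes the file even if parsing raises.
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

os.unlink(cfg_path)
print(cfg["framework"])  # → jax
```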
primus/backends/maxtext/max_utils.py:1 - Operator precedence issue: the condition mixes `or` and `and` without parentheses. Due to operator precedence, this evaluates as `(self.wandb_save_dir is None) or (self.wandb_save_dir == '' and self.base_output_directory)`, which may not be the intended logic. Add explicit parentheses: `if (self.wandb_save_dir is None or self.wandb_save_dir == '') and self.base_output_directory:`
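A minimal demonstration of the divergence: with `wandb_save_dir = None` and an empty `base_output_directory`, the unparenthesized form short-circuits to `True` on the `or` while the parenthesized form correctly falls through. Plain variables stand in for the `self.` attributes:

```python
# In Python, `and` binds tighter than `or`, so the unparenthesized
# condition groups as:  A is None  or  (A == '' and B)
wandb_save_dir = None        # "unset" save dir
base_output_directory = ""   # also unset: the guarded branch must NOT run

# Buggy form: True as soon as wandb_save_dir is None, ignoring B entirely.
buggy = wandb_save_dir is None or wandb_save_dir == "" and base_output_directory

# Fixed form with explicit parentheses, as the comment suggests.
fixed = (wandb_save_dir is None or wandb_save_dir == "") and base_output_directory

print(bool(buggy), bool(fixed))  # → True False
```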
###############################################################################
primus/backends/maxtext/max_utils.py:1 - Same operator precedence issue as above. Should be: `if (self.wandb_exp_name is None or self.wandb_exp_name == '') and self.run_name:`
###############################################################################
Accept Copilot commit suggestion. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…le model override args set